"""
Created on Thu Jul  1 15:08:00 2021

@author: LSY
"""

from __future__ import print_function, division # Grab some handy Python3 stuff.
from Sens_proj_old import SensAnal
import mpctools as mpc
import numpy as np
import casadi
# Print arrays with 4 digits after the decimal point and a line width of 400
# characters so that rows are not wrapped.
np.set_printoptions(precision=4, linewidth=400)

# Flag forwarded unchanged to SensAnal.
rank_xp = True

Nx = 8  # Number of states
Nu = 4  # Number of inputs

# Every quantity below is tied to the number of states:
#   Nymax - maximum number of measurements
#   order - number of Lie derivatives (from 0 to order-1)
#   Nw    - number of process-noise channels
#   Nv    - number of measurement-noise channels
Nymax = order = Nw = Nv = Nx

# Sampling interval in model time units (1/120).
# NOTE(review): the previous comment said "20s"; if the model time unit is
# hours this value is 30 s -- confirm the intended interval/unit.
Delta = 1 / 60 / 2

Nsim = 50  # Number of simulated sampling steps

# Symbolic column vectors for the state (Nx-by-1) and the input (Nu-by-1).
x = casadi.SX.sym("x", Nx, 1)
u = casadi.SX.sym("u", Nu, 1)

# Model constants used by FourCSTR (one entry per reactor unless noted).
F_01, F_02, F_03, F_04 = 5, 10, 8, 12          # feed terms paired with C0i / T_0i
T_01, T_02, T_03, T_04 = 300, 300, 300, 300    # feed values for the even-indexed states
C01, C02, C03, C04 = 4, 2, 3, 3.5              # feed values for the odd-indexed states
V_1, V_2, V_3, V_4 = 1, 3, 4, 6                # divisors of the flow terms (presumably volumes)
F1, F2, F3 = 35, 45, 33                        # inter-reactor flow terms
Fr1, Fr2 = 20, 10                              # recycle-type terms feeding the first reactor
D_H1, D_H2, D_H3 = -5e4, -5.2e4, -5e4          # enter only through alpha_1..alpha_3
k_10, k_20, k_30 = 3e6, 3e5, 3e5               # pre-exponential factors
E_1, E_2, E_3 = 5e4, 7.5e4, 7.53e4             # numerators of the exp(-E/R/T) terms
R = 8.314
C_p = 0.231
rho = 1000

# Pre-exponential factors of the heat-generation terms.
alpha_1 = D_H1*k_10/(rho*C_p)
alpha_2 = D_H2*k_20/(rho*C_p)
alpha_3 = D_H3*k_30/(rho*C_p)

# Reference state and input used for normalization and simulation.
xs = np.array([2.78, 363, 2.58, 356, 2.6, 355, 2.6, 392])
us = np.array([1e4, 2e4, 2.5e4, 1e4])

# The normalized state 0..1 maps onto the physical range [lb*xs, ub*xs].
lb = 0.5
ub = 1.5
width = ub - lb

def FourCSTR(x, u):
    """Right-hand side of the normalized four-reactor model.

    Parameters
    ----------
    x : indexable with 8 entries
        Normalized states. Entry ``i`` is mapped to its physical value with
        ``x[i]*(width*xs[i]) + lb*xs[i]`` before the balances are evaluated.
    u : indexable with 4 entries
        Heat inputs, one per reactor.

    Returns
    -------
    numpy.ndarray
        The 8 time derivatives, each divided by ``width*xs[i]`` so that they
        are expressed in the normalized coordinates of ``x``.
    """
    # Scaling between normalized and physical coordinates.
    scale = [width*xs[i] for i in range(8)]
    phys = [x[i]*scale[i] + lb*xs[i] for i in range(8)]

    # Odd positions pair with the C0i feeds, even positions with the T_0i
    # feeds and appear inside the exponentials.
    c1, t1, c2, t2, c3, t3, c4, t4 = phys
    Qh1, Qh2, Qh3, Qh4 = u[0], u[1], u[2], u[3]

    def rate(temp):
        # Sum of the three exponential rate terms at the given state value.
        return k_10*np.exp(-E_1/R/temp) + k_20*np.exp(-E_2/R/temp) + k_30*np.exp(-E_3/R/temp)

    def heat(temp):
        # Same exponentials weighted by alpha_1..alpha_3.
        return alpha_1*np.exp(-E_1/R/temp) + alpha_2*np.exp(-E_2/R/temp) + alpha_3*np.exp(-E_3/R/temp)

    # Reactor 1 (receives both Fr1 and Fr2 terms).
    dx1 = (F_01/V_1*(C01-c1) + Fr1/V_1*(c2-c1) + Fr2/V_1*(c4-c1)
           - rate(t1)*c1)/scale[0]
    dx2 = (F_01/V_1*(T_01-t1) + Fr1/V_1*(t2-t1) + Fr2/V_1*(t4-t1)
           - heat(t1)*c1 + Qh1/rho/C_p/V_1)/scale[1]

    # Reactor 2.
    dx3 = (F1/V_2*(c1-c2) + F_02/V_2*(C02-c2)
           - rate(t2)*c2)/scale[2]
    dx4 = (F1/V_2*(t1-t2) + F_02/V_2*(T_02-t2)
           - heat(t2)*c2 + Qh2/rho/C_p/V_2)/scale[3]

    # Reactor 3.
    dx5 = ((F2-Fr1)/V_3*(c2-c3) + F_03/V_3*(C03-c3)
           - rate(t3)*c3)/scale[4]
    dx6 = ((F2-Fr1)/V_3*(t2-t3) + F_03/V_3*(T_03-t3)
           - heat(t3)*c3 + Qh3/rho/C_p/V_3)/scale[5]

    # Reactor 4.
    dx7 = (F3/V_4*(c3-c4) + F_04/V_4*(C04-c4)
           - rate(t4)*c4)/scale[6]
    dx8 = (F3/V_4*(t3-t4) + F_04/V_4*(T_04-t4)
           - heat(t4)*c4 + Qh4/rho/C_p/V_4)/scale[7]

    return np.array([dx1, dx2, dx3, dx4, dx5, dx6, dx7, dx8])

# Wrap the model as CasADi functions: the plain right-hand side, and a
# variant built with rk4=True over one sampling interval Delta (M=1).
F = mpc.getCasadiFunc(FourCSTR, [Nx, Nu], ["x", "u"], "cstrmodel")
F_rk4 = mpc.getCasadiFunc(
    FourCSTR, [Nx, Nu], ["x", "u"], "ode_rk4",
    rk4=True, Delta=Delta, M=1
)

# Simulator used below to advance the state one sample at a time.
cstr4 = mpc.DiscreteSimulator(FourCSTR, Delta, [Nx, Nu], ["x", "u"])

# Standard deviations of the process and measurement noise.
sigma_w = sigma_v = 0.001

# Gaussian noise sequences, one row per sampling step (w is drawn before v).
w = sigma_w*np.random.randn(Nsim, Nw)
v = sigma_v*np.random.randn(Nsim, Nv)

# Hold each of the four inputs constant at its reference value.
usim = np.zeros((Nsim + 1, Nu))
usim[:, 0:4] = us[0:4]

# Generate the state trajectory, starting from 0.5 in every normalized state
# and adding process noise after each simulated step.
xsim = np.zeros((Nsim + 1, Nx))
xsim[0, :] = np.full(8, 0.5)
for t in range(Nsim):
    xsim[t + 1, :] = cstr4.sim(xsim[t, :], usim[t, :]) + w[t, :]
    
# Indices of the states that are measured initially (all eight).
M_known = list(range(8))
# Current sensor set; starts as the same list object as M_known.
M_remain = M_known
# Placeholder until a sensor is actually removed.
M_remove = np.zeros((1, 1))
# Number of measurements. Nx is used for the sensor-removal study; the
# original note mentions 2 as the value for when the "begin remove" part is
# commented out -- TODO confirm.
Ny = Nx

# ---- Initial analysis with every measurement available ----
print('Consider all the measurements:', M_known)
ysim = np.zeros((Nsim, Ny))
def measurement(x):
    """Return the entries of ``x`` selected by the current ``M_remain``.

    The selection is a matrix product with the rows of the 8x8 identity
    indexed by the module-level list ``M_remain``, which is read every time
    the function is called.
    """
    selector = np.eye(8)[M_remain, :]
    return np.dot(selector, x)
# Build the noisy measurement sequence: selected states plus the matching
# components of the measurement noise v.
C_all = np.eye(8)
C = C_all[M_remain, :]
for t in range(Nsim):
    ysim[t, :] = measurement(xsim[t, :]) + np.dot(C, v[t, :])

# Sensitivity analysis for the full sensor set.
H = mpc.getCasadiFunc(measurement, [Nx], ["x"], "measurement")
Flag = SensAnal(
    F_rk4, H,
    xsim[0:Nsim, :], ysim[0:Nsim, :], usim[0:Nsim-1, :],
    sigma_w, sigma_v, Nsim-1, rank_xp
)

# Flag layout as consumed here: first entry, middle entries, last entry.
Rank_Sen = Flag[0]
print('Rank of Sensiti:', Rank_Sen)
Sen_state_sort = Flag[1:-1]
print("Sen state sort:",Sen_state_sort)
Sen_deg = Flag[-1]
print("Sen degree:",Sen_deg)


# ---- Greedy sensor removal ----
# Each round tries dropping every remaining sensor in turn, re-runs SensAnal
# on the reduced measurement set, and keeps the reduction that reports the
# largest sensitivity degree. The search stops as soon as no candidate of a
# round reaches rank Nx.
# NOTE(review): the kept candidate is chosen by Sen_deg alone, while the stop
# test uses the best rank over all candidates of the round -- confirm these
# cannot refer to different candidates.
# NOTE(review): if every round succeeds, the last round (k == Nx-1) evaluates
# an empty measurement set -- confirm the helpers handle that case.
for k in range(Nx):
    Rank_Sen_remain_max = np.zeros((1, 1))  # best rank seen in this round
    print(k+1,"-th remove: ")
    M_remain_try = []
    M_remain_copy = []
    Ny = Ny - 1          # every candidate has one measurement fewer
    Sen_deg_max = 0

    for i in M_remain:
        M_remove_try = i
        print('try to remove', M_remove_try)

        # Candidate set: the current sensors without sensor i.
        M_remain_try[:] = M_remain
        M_remain_try.remove(i)
        print('remian these measurements', M_remain_try)

        ysim = np.zeros((Nsim, Ny))

        def measurement(x):
            # Pick the entries of x that belong to the candidate sensor set.
            return np.dot(np.eye(8)[M_remain_try, :], x)

        # Noisy measurements for the candidate set; the selector is the same
        # for every time step, so it is built once.
        C_all = np.eye(8)
        C = C_all[M_remain_try, :]
        for t in range(Nsim):
            ysim[t, :] = measurement(xsim[t, :]) + np.dot(C, v[t, :])

        H = mpc.getCasadiFunc(measurement, [Nx], ["x"], "measurement")
        Flag = SensAnal(
            F_rk4, H,
            xsim[0:Nsim, :], ysim[0:Nsim, :], usim[0:Nsim-1, :],
            sigma_w, sigma_v, Nsim-1, rank_xp
        )

        Rank_Sen_remain = Flag[0]
        print('Rank of Sensiti:', Rank_Sen_remain)
        Sen_state_order = Flag[1:-1]
        print("Sen state order:",Sen_state_order)
        Sen_deg = Flag[-1]
        print("Sen degree: %.5g" % Sen_deg)

        # Track the best rank among the candidates of this round.
        if Rank_Sen_remain_max < Rank_Sen_remain:
            Rank_Sen_remain_max[0, 0] = Rank_Sen_remain

        # Remember the candidate with the largest sensitivity degree; ties go
        # to the candidate examined last.
        if Sen_deg >= Sen_deg_max:
            Sen_deg_max = Sen_deg
            M_remove = M_remove_try
            M_remain_copy[:] = M_remain_try
        print("----------------------------------------")

    # No candidate reached rank Nx: keep the current sensor set and stop.
    if Rank_Sen_remain_max < Nx:
        print("stop trying")
        break

    print("new removed:",M_remove)
    M_remain = M_remain_copy
    print("remians:",M_remain)
print("Finially remians:",M_remain)
        
